#!/usr/bin/env python3
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "GitPython",
#     "PyGithub",
#     "rich-click",
#     "packaging",
#     "rich",
# ]
# ///

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# This tool is based on the Spark merge_spark_pr script:
# https://github.com/apache/spark/blob/master/dev/merge_spark_pr.py

from __future__ import annotations

import re
import sys
from collections import Counter, defaultdict
from typing import TYPE_CHECKING

import git
import rich_click as click
from github import Auth, Github
from packaging import version
from rich.console import Console
from rich.progress import Progress

if TYPE_CHECKING:
    from github.Issue import Issue

# Keys of the per-commit dicts built by get_commits_between(); the order matches the
# placeholders joined into GIT_LOG_FORMAT below.
GIT_COMMIT_FIELDS = ["id", "author_name", "author_email", "date", "subject", "body"]
# ``git log --format`` string: fields are separated by %x1f and every commit record is
# terminated by %x1e, which get_commits_between() splits on as "\x1f" / "\x1e".
GIT_LOG_FORMAT = "%x1f".join(["%h", "%an", "%ae", "%ad", "%s", "%b"]) + "%x1e"
# Matches a commit subject that ends in "(#<1-6 digits>)" and captures "#<digits>".
pr_title_re = re.compile(r".*\((#[0-9]{1,6})\)$")

# Colour passed to click.style() by style_issue_status() for each capitalised status.
STATUS_COLOR_MAP = {
    "Closed": "yellow",
    "Merged": "green",
    "Open": "red",
}

# Fallback changelog section: used for PRs without a "type:" label and for commits
# whose subject carries no PR number.
DEFAULT_SECTION_NAME = "Uncategorized"


def get_commits_between(
    repo,
    previous_version,
    target_version,
    files=None,
):
    """Return the commits in ``previous_version..target_version`` as a list of dicts.

    :param repo: object exposing ``repo.git.log(*args)`` (a GitPython ``Repo``).
    :param previous_version: revision that starts the range (exclusive).
    :param target_version: revision that ends the range (inclusive).
    :param files: optional iterable of paths; when given, only commits touching
        at least one of them are returned.
    :return: one dict per commit keyed by ``GIT_COMMIT_FIELDS``; an empty list
        when the range (or the path filter) selects no commits.
    """
    log_args = [f"--format={GIT_LOG_FORMAT}", f"{previous_version}..{target_version}"]
    if files:
        # Every path has to be its own pathspec argument; joining them with a
        # space would make git look for one path that contains spaces.
        log_args.append("--")
        log_args.extend(files)

    raw_log = repo.git.log(*log_args)

    commits = []
    for record in raw_log.split("\x1e"):
        # Strip only real whitespace: a bare ``str.strip()`` also removes the
        # "\x1f" field separator (Python treats it as whitespace), which drops
        # trailing empty fields such as an empty commit body.
        record = record.strip(" \t\r\n")
        if not record:
            # Empty log output or the remainder after the final "\x1e": not a commit.
            continue
        commits.append(dict(zip(GIT_COMMIT_FIELDS, record.split("\x1f"))))
    return commits


def style_issue_status(status):
    """Return ``status`` cut/padded to six characters, coloured when a colour is known."""
    cell = status[:6].ljust(6)
    if status not in STATUS_COLOR_MAP:
        # Unknown statuses are still aligned, just left uncoloured.
        return cell
    return click.style(cell, STATUS_COLOR_MAP[status])


def get_issue_type(issue):
    """Derive the changelog section of an issue/PR from its labels.

    Labels are scanned in the order they are listed and the first decisive one wins:

    * a label starting with ``type:`` yields the rest of the label name (stripped);
    * the label ``changelog:skip`` yields the marker ``"(skip)"``.

    :param issue: object with a ``labels`` iterable whose items expose ``name``.
    :return: the section name, ``"(skip)"``, or ``DEFAULT_SECTION_NAME`` when no
        label decides.
    """
    label_prefix = "type:"
    for label in issue.labels or ():
        name = label.name
        if name.startswith(label_prefix):
            # removeprefix() drops only the leading marker; str.replace() would also
            # delete any later "type:" occurrence inside the label text.
            return name.removeprefix(label_prefix).strip()
        if name == "changelog:skip":
            return "(skip)"
    return DEFAULT_SECTION_NAME


def build_main_commits_cache(repo: git.Repo, issue_numbers: set[int]) -> dict[int, str]:
    """Build a cache of PR number -> ``origin/main`` commit SHA with a single git log call.

    Only commits whose subject ends in ``(#<number>)`` are considered, and only
    numbers present in ``issue_numbers`` are stored. The log is walked in the order
    git prints it and later lines overwrite earlier ones for the same PR number.

    :param repo: repository whose ``origin/main`` history is scanned.
    :param issue_numbers: PR numbers worth caching.
    :return: mapping of PR number to full commit SHA; empty when ``issue_numbers``
        is empty or the history cannot be read (a warning is printed to stderr).
    """
    cache: dict[int, str] = {}
    if not issue_numbers:
        return cache

    log_format = "--format=%H %s"
    try:
        # Limit the scan to the last year instead of walking the entire history.
        log_output = repo.git.log("origin/main", log_format, "--since=1 year ago")
    except git.GitCommandError:
        try:
            # Fall back to a fixed number of commits if the date-limited call fails.
            log_output = repo.git.log("origin/main", log_format, "-2000")
        except git.GitCommandError as exc:
            # Best effort: commit SHAs are informational, so report and carry on.
            print(f"Warning: could not read origin/main history: {exc}", file=sys.stderr)
            return cache

    # PR number at the very end of the "<sha> <subject>" line.
    pr_suffix_re = re.compile(r"\(#(\d+)\)$")
    for commit_line in log_output.splitlines():
        match = pr_suffix_re.search(commit_line)
        if match is None:
            continue
        pr_number = int(match.group(1))
        if pr_number in issue_numbers:
            cache[pr_number] = commit_line.split(" ", 1)[0]

    return cache


def get_commit_in_main_associated_with_pr(repo: git.Repo, issue: Issue, main_commits_cache: dict[int, str]) -> str | None:
    """Look up the main-branch commit SHA recorded for a PR.

    Returns ``None`` for items that are not pull requests and for PRs that have
    no entry in ``main_commits_cache``. ``repo`` is accepted but not consulted.
    """
    if not issue.pull_request:
        return None
    # The cache is expected to have been filled before this is called.
    return main_commits_cache.get(issue.number)


def build_cherrypicked_cache(repo: git.Repo, issue_numbers: list[int], previous_version: str | None = None) -> dict[int, bool]:
    """Work out which PRs are referenced on the current branch with a single git log call.

    A PR counts as cherry-picked when ``(#<number>)`` appears anywhere in a commit
    subject of the scanned range.

    :param repo: repository whose current branch is scanned.
    :param issue_numbers: PR numbers to look for.
    :param previous_version: when given, only ``<previous_version>..`` is scanned.
    :return: mapping of every number in ``issue_numbers`` to ``True``/``False``.
        If git fails, a warning is printed to stderr and every entry stays ``False``.
    """
    cache = dict.fromkeys(issue_numbers, False)
    if not issue_numbers:
        return cache

    log_args = ["--format=%H %s"]
    if previous_version:
        log_args.append(f"{previous_version}..")

    try:
        log_output = repo.git.log(*log_args)
    except git.GitCommandError as exc:
        # There is no per-issue fallback: report the failure instead of silently
        # presenting every PR as not cherry-picked.
        print(f"Warning: could not read branch history, cherry-pick status unknown: {exc}", file=sys.stderr)
        return cache

    # The pattern cannot span a newline, so scanning the whole output at once finds
    # the same references as scanning it line by line.
    pr_ref_re = re.compile(r"\(#(\d+)\)")
    for issue_num in map(int, pr_ref_re.findall(log_output)):
        if issue_num in cache:
            cache[issue_num] = True

    return cache


def is_cherrypicked(repo: git.Repo, issue: Issue, previous_version: str | None = None) -> bool:
    """Tell whether ``(#<issue.number>)`` shows up in a commit subject on the current branch.

    When ``previous_version`` is given only ``<previous_version>..`` is searched.
    """
    reference = f"(#{issue.number})"
    log_args = ["--format=%H %s", f"--grep={reference}"]
    if previous_version:
        log_args.append(previous_version + "..")

    # --grep narrows the candidates; the substring test confirms the reference is on
    # the printed "<sha> <subject>" line (cherry-picks may carry several PR numbers).
    matching_lines = repo.git.log(*log_args).splitlines()
    return any(reference in line for line in matching_lines)


def is_pr(issue: Issue) -> bool:
    """Return True when the item's ``html_url`` points at an apache/airflow pull request."""
    pull_url_fragment = "apache/airflow/pull/"
    return pull_url_fragment in issue.html_url


def files_touched(repo, commit_id: str) -> list[str]:
    """Return the paths changed by ``commit_id`` as a real list.

    :param repo: object exposing ``repo.commit(rev)`` (a GitPython ``Repo``).
    :param commit_id: revision identifying the commit.
    :return: the keys of the commit's ``stats.files`` mapping, materialised as a
        list so the result matches the annotation (the original returned a
        ``dict_keys`` view).
    """
    real_commit = repo.commit(commit_id)
    return list(real_commit.stats.files)


def is_core_commit(files: list[str]) -> bool:
    """Return True when at least one of ``files`` lies outside the known non-core paths.

    The function lists what is *not* core (providers, chart, CI/dev tooling,
    non-released docs, ...) and reports whether the commit touches anything else.

    :param files: paths changed by a commit, as reported by git stats.
    :return: ``True`` if any path is not covered by a non-core prefix; ``False``
        when every path is non-core or ``files`` is empty.
    """
    # Prefixes are matched with str.startswith(), so no glob characters here.
    non_core_prefixes = (
        # Providers
        "airflow/providers/",
        "airflow/provider.yaml.schema.json",
        "docs/apache-airflow-providers-",
        # chart
        "chart/",
        "docs/helm-chart/",
        "clients",
        # non-released docs
        "COMMITTERS.rst",
        "contributing_docs/",
        "INTHEWILD.md",
        "INSTALL",
        "README.md",
        "images/",
        "codecov.yml",
        "kubernetes-tests/",
        ".github/",
        ".pre-commit-config.yaml",
        "yamllint-config.yml",
        ".markdownlint.yml",
        "tests/",
        "dev/",
        "Dockerfile",
        "Dockerfile.ci",
        ".hadolint.yaml",
        "scripts/",
        "docs/build_docs.py",
        "docs/start_doc_server.sh",
        ".devcontainer/",
        "docker-context-files/",
        # Misc
        ".dockerignore",
        "docs/spelling_wordlist.txt",
        # Was "docs/integration-logos/*": the literal "*" could never match a path.
        "docs/integration-logos/",
        "docs/exts/",
        "docs/docker-stack",
        "docs/README.rst",
        "docker-tests/",
        "helm-tests/",
        ".asf.yaml",
        ".mailmap",
        "breeze-legacy",
        "breeze-complete",
        ".rat-excludes",
        ".gitattributes",
        ".gitpod.yml",
        "generated/",
    )

    for file in files:
        if file.startswith(non_core_prefixes):
            continue
        # Handle renaming. Form: {old_name => new_name}somename.py
        # The entry is non-core if either side of the rename has a non-core prefix.
        if file.startswith("{"):
            rename_sides = file[1:].split(" => ")
            if any(side.strip().startswith(non_core_prefixes) for side in rename_sides):
                continue
        return True
    return False


def print_changelog(sections):
    """Print each section as a heading, an underline of double quotes, and a bullet list.

    The ``"(skip)"`` section is left out; every printed section is followed by a blank line.
    """
    for heading, entries in sections.items():
        if heading != "(skip)":
            print(heading)
            # Underline exactly as wide as the heading text.
            print(len(heading) * '"')
            for entry in entries:
                print("-", entry)
            print()


# Root command group: the subcommands in this file attach to it through ``@cli.command``.
# The function body is only its docstring, which click renders as the group's ``--help``
# text, so any edit to it changes what the CLI prints.
@click.group()
def cli():
    r"""
    This tool should be used by Airflow Release Manager to verify what GitHub issues
     were merged in the current working branch.

        airflow-github compare <target_version> <github_token>
    """


# ``compare`` subcommand: prints one table row per PR/issue on the GitHub milestone
# "Airflow <target_version>" of apache/airflow, showing whether each PR is referenced by a
# commit on the currently checked-out branch, followed by a summary line.
# Documented with comments rather than a docstring because click would render a
# docstring as the command's ``--help`` text.
@cli.command(short_help="Compare a GitHub target version against git merges")
@click.argument("target_version")
@click.argument("github-token", envvar="GITHUB_TOKEN")
@click.option(
    "--previous-version",
    "previous_version",
    help="Specify the previous tag on the working branch to limit"
    " searching for few commits to find the cherry-picked commits",
)
@click.option("--unmerged", "show_uncherrypicked_only", help="Show unmerged PRs only", is_flag=True)
@click.option("--show-commits", help="Show commit SHAs (default: on, off when --unmerged)", is_flag=True, default=None)
def compare(target_version, github_token, previous_version=None, show_uncherrypicked_only=False, show_commits=None):
    # Set smart defaults: ``show_commits`` is None when the flag was not passed.
    if show_commits is None:
        show_commits = not show_uncherrypicked_only  # Default off for --unmerged

    repo = git.Repo(".", search_parent_directories=True)

    github_handler = Github(auth=Auth.Token(github_token))

    # Fetch PRs and Issues separately, with merged PRs identified upfront
    # (three searches: merged, closed-but-not-merged, and still open).
    merged_prs: list[Issue] = list(
        github_handler.search_issues(
            f'repo:apache/airflow milestone:"Airflow {target_version}" is:pull-request is:merged'
        )
    )
    closed_prs: list[Issue] = list(
        github_handler.search_issues(
            f'repo:apache/airflow milestone:"Airflow {target_version}" is:pull-request is:closed -is:merged'
        )
    )
    open_prs: list[Issue] = list(
        github_handler.search_issues(
            f'repo:apache/airflow milestone:"Airflow {target_version}" is:pull-request is:open'
        )
    )

    # Skip fetching issues if we only care about unmerged PRs
    if show_uncherrypicked_only:
        issues = []
    else:
        issues: list[Issue] = list(
            github_handler.search_issues(
                f'repo:apache/airflow milestone:"Airflow {target_version}" is:issue'
            )
        )

    # Create a merge status lookup: PR number -> True (merged) / False (closed unmerged)
    pr_merge_status_cache = {}
    for pr in merged_prs:
        pr_merge_status_cache[pr.number] = True
    for pr in closed_prs:
        pr_merge_status_cache[pr.number] = False
    # Open PRs are neither merged nor closed, so we don't need to cache them

    milestone_issues = merged_prs + closed_prs + open_prs + issues

    # Summary counters: cherry-picked items, and not-yet-picked PRs grouped by status.
    num_cherrypicked = 0
    num_uncherrypicked = Counter()

    # Format-spec reminders for the row templates below:
    # :<N left-aligns and pads to N, :>N right-aligns and pads to N,
    # :<N.N pads to N and truncates after N chars, !s forces as string.
    # ``status`` and ``merged`` carry no width because their values arrive already
    # padded to 6 characters (and possibly wrapped in colour codes).
    if show_commits:
        formatstr = (
            "{number:>6} | {typ!s:<5} | {changelog!s:<13} | {status!s} "
            "| {title:<83.83} | {merged:<6} | {commit:>7.7} | {url}"
        )
        header_fields = {
            "number": "NUMBER",
            "typ": "TYPE",
            "changelog": "CHANGELOG",
            "status": "STATUS".ljust(6),
            "title": "TITLE",
            "merged": "CHERRY",
            "commit": "COMMIT",
            "url": "URL",
        }
    else:
        # Same table without the COMMIT column; the title column is wider instead.
        formatstr = (
            "{number:>6} | {typ!s:<5} | {changelog!s:<13} | {status!s} "
            "| {title:<95.95} | {merged:<6} | {url}"
        )
        header_fields = {
            "number": "NUMBER",
            "typ": "TYPE",
            "changelog": "CHANGELOG",
            "status": "STATUS".ljust(6),
            "title": "TITLE",
            "merged": "CHERRY",
            "commit": "",  # Not used
            "url": "URL",
        }

    print(formatstr.format(**header_fields))
    # Newest first: order by close time, falling back to creation time for open items.
    milestone_issues = sorted(
        milestone_issues, key=lambda x: x.closed_at if x.closed_at else x.created_at, reverse=True
    )

    # Build caches for performance optimization (PR numbers only; plain issues are excluded)
    issue_numbers = [issue.number for issue in milestone_issues if is_pr(issue)]

    # Convert to set for O(1) lookups in cache building
    issue_numbers_set = set(issue_numbers)

    # Build all caches upfront with batch operations
    if show_commits:
        main_commits_cache = build_main_commits_cache(repo, issue_numbers_set)
    else:
        main_commits_cache = {}

    cherrypicked_cache = build_cherrypicked_cache(repo, issue_numbers, previous_version)

    for issue in milestone_issues:
        issue_is_pr = is_pr(issue)

        # Determine status - differentiate between Closed and Merged for PRs using cache
        if issue_is_pr and issue.state == "closed":
            is_merged = pr_merge_status_cache.get(issue.number, False)
            status = "Merged" if is_merged else "Closed"
        else:
            status = issue.state.capitalize()

        # Checks if commit was cherrypicked into branch using cache.
        # Plain issues are not in the batch cache (it is built from PR numbers only),
        # so they are looked up with an individual git log call.
        if issue_is_pr:
            is_cherry_picked = cherrypicked_cache.get(issue.number, False)
        else:
            is_cherry_picked = is_cherrypicked(repo, issue, previous_version)

        if is_cherry_picked:
            num_cherrypicked += 1
            if show_uncherrypicked_only:
                continue
            cherrypicked = click.style("Yes".ljust(6), "green")
        elif not issue_is_pr and show_uncherrypicked_only:
            # Don't show issues when looking for unmerged PRs
            continue
        elif issue_is_pr:
            num_uncherrypicked[status] += 1
            cherrypicked = click.style("No".ljust(6), "red")
        else:
            # Plain issue that is not referenced on the branch: leave the column blank.
            cherrypicked = ""

        fields = dict(
            number=issue.number,
            typ="PR" if issue_is_pr else "Issue",
            changelog=get_issue_type(issue) if issue_is_pr else "",
            status=style_issue_status(status),
            title=issue.title,
            url=issue.html_url,
        )

        # Only get commit info if we're showing commits
        # (the row template without the COMMIT column never reads ``commit``).
        if show_commits:
            commit_in_main = get_commit_in_main_associated_with_pr(repo, issue, main_commits_cache)
            fields["commit"] = commit_in_main if commit_in_main else ""

        print(formatstr.format(**fields, merged=cherrypicked))

    # Summary: items found on the branch vs. PRs still missing, broken down by status.
    print(
        f"Commits on branch: {num_cherrypicked:d}, {sum(num_uncherrypicked.values()):d} "
        f"({dict(num_uncherrypicked)}) yet to be cherry-picked"
    )


# ``changelog`` subcommand: groups the commit subjects of ``previous_version..target_version``
# by the "type:" label of their PR and prints them as changelog sections.
# Commits whose PR number already appears in RELEASE_NOTES.rst (read from the current
# directory) are skipped, and PR commits that touch no core file are left out.
# Documented with comments rather than a docstring because click would render a
# docstring as the command's ``--help`` text.
@cli.command(short_help="Build a CHANGELOG grouped by GitHub Issue type")
@click.argument("previous_version")
@click.argument("target_version")
@click.argument("github-token", envvar="GITHUB_TOKEN")
@click.option("--disable-progress-bar", is_flag=True, help="Disable the progress bar")
def changelog(previous_version, target_version, github_token, disable_progress_bar):
    repo = git.Repo(".", search_parent_directories=True)
    # Get a list of issues/PRs that have been committed on the current branch.
    log = get_commits_between(repo, previous_version, target_version)

    if disable_progress_bar:
        print(f"Processing {len(log)} commits")

    gh = Github(auth=Auth.Token(github_token))
    gh_repo = gh.get_repo("apache/airflow")
    sections = defaultdict(list)
    with open("RELEASE_NOTES.rst", encoding="utf-8") as file:
        # A set keeps the per-commit membership test cheap.
        existing_pr_nos = set(re.findall(r"\(#(\d+)\)", file.read()))

    console = Console(width=180)

    with Progress(console=console, disable=disable_progress_bar) as progress:
        # progress.track() registers its own task and advances it once per item, so no
        # separate add_task()/update() bookkeeping is needed. Doing both rendered two
        # bars, and the manual one was not advanced for skipped commits.
        for commit in progress.track(log, description="Processing commits from changelog"):
            subject = commit["subject"]
            tickets = pr_title_re.findall(subject)
            if not tickets:
                # No "(#NNNN)" suffix: nothing to look up on GitHub.
                sections[DEFAULT_SECTION_NAME].append(subject)
                continue

            issue = gh_repo.get_issue(number=int(tickets[0][1:]))
            if str(issue.number) in existing_pr_nos:
                print(f"Skipping {issue.number} as it's already in RELEASE_NOTES.rst")
                continue
            # Only inspect the commit's files once it is known not to be skipped.
            if is_core_commit(files_touched(repo, commit["id"])):
                sections[get_issue_type(issue)].append(subject)

    console.print("\n")
    print_changelog(sections)


# ``needs-categorization`` helper: walks the commits of ``previous_version..target_version``
# and reports PRs that still fall into DEFAULT_SECTION_NAME although they touch core files.
# Commits without a "(#NNNN)" suffix are reported as missing a PR number.
# Documented with comments rather than a docstring because click would render a
# docstring as the command's ``--help`` text.
@cli.command(short_help="Find merged PRs that still need to be categorized for the changelog")
@click.argument("previous_version")
@click.argument("target_version")
@click.option("--show-skipped", is_flag=True)
@click.option("--show-files", is_flag=True)
@click.argument("github-token", envvar="GITHUB_TOKEN")
def needs_categorization(previous_version, target_version, show_skipped, show_files, github_token):
    repo = git.Repo(".", search_parent_directories=True)
    commits = get_commits_between(repo, previous_version, target_version)

    gh_repo = Github(auth=Auth.Token(github_token)).get_repo("apache/airflow")

    for commit in commits:
        subject = commit["subject"]
        tickets = pr_title_re.findall(subject)
        if not tickets:
            print(f"Commit '{commit['id']}' is missing PR number: {subject}")
            continue

        issue = gh_repo.get_issue(number=int(tickets[0][1:]))
        if get_issue_type(issue) != DEFAULT_SECTION_NAME:
            # Already categorized through a label: nothing to report.
            continue

        files = files_touched(repo, commit["id"])
        if is_core_commit(files):
            print(f"{subject}: {issue.html_url}")
        elif show_skipped:
            print(f"**** {subject}: Skipping - not a core commit")
        else:
            continue

        # Reached only when one of the two lines above was printed.
        if show_files:
            for f in files:
                print(f"\t{f}")


# ``api-clients-policy`` subcommand: major and minor bumps always call for a client release;
# for a patch bump the decision depends on whether any commit in the range touched the
# generated OpenAPI spec (commits that only update the Airflow version are ignored).
@cli.command(
    name="api-clients-policy",
    help="Compare two airflow core release tags and determine if API clients need to be released.",
)
@click.argument("previous_version")
@click.argument("target_version")
def api_clients_policy(previous_version, target_version):
    p_version, t_version = version.parse(previous_version), version.parse(target_version)

    bump_messages = (
        ("major", "This is a major release, API clients should also be released."),
        ("minor", "This is a minor release, API clients should also be released."),
    )
    for component, message in bump_messages:
        if getattr(p_version, component) != getattr(t_version, component):
            print(message)
            return

    if p_version == t_version:
        print("Both versions are identical")
        return

    repo = git.Repo(".", search_parent_directories=True)
    spec_path = f"{repo.working_dir}/airflow-core/src/airflow/api_fastapi/core_api/openapi/v2-rest-api-generated.yaml"
    spec_commits = get_commits_between(repo, previous_version, target_version, files=[spec_path])

    clients_need_release = False
    for commit in spec_commits:
        subject = commit["subject"]
        if "update airflow version to" in subject.lower():
            # Version-bump commits touch the spec without changing the API itself.
            continue
        clients_need_release = True
        print(f"Commit '{commit['id']}' updated the OpenAPI spec, PR number: {subject}")

    if not clients_need_release:
        print(
            "API clients don't need to be released because the API spec hasn't changed between those two "
            "patch versions"
        )
    else:
        print(f"API clients need to be released because the API spec has changed since '{previous_version}'")


if __name__ == "__main__":
    # Run the module's doctests first and refuse to start the CLI if any of them fails.
    import doctest

    doctest_results = doctest.testmod()
    if doctest_results.failed:
        sys.exit(-1)
    # Exceptions raised by the CLI propagate unchanged.
    cli()
